{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

{# ------------------------------------------------------------------------- #}
{# Utilities for C++ code generation templates.                              #}
{# ------------------------------------------------------------------------- #}

{# Emit the Eigen column-vector type with the given number of rows.
 #
 # Implemented as a matrix_type with a single column.
 #
 # Args:
 #     dim (int): Number of rows
 #     scalar_type (str): Scalar type placed first in the Eigen::Matrix template arguments
 #}
{%- macro vector_type(dim, scalar_type="Scalar") -%}
{{ matrix_type(dim, 1, scalar_type) }}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Emit the fixed-size Eigen matrix type for the given shape.
 #
 # Args:
 #     rows (int): Number of rows
 #     cols (int): Number of columns
 #     scalar_type (str): Scalar type placed first in the Eigen::Matrix template arguments
 #}
{%- macro matrix_type(rows, cols, scalar_type="Scalar") -%}
    {%- set template_args = [scalar_type, rows, cols] | join(", ") -%}
    Eigen::Matrix<{{ template_args }}>
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Convert a type (or a value of that type) to the C++ type name to emit
 #
 # Branches are tried in this order:
 #   - Symbol, symbolic values and float: scalar_type, or a `* const` pointer when the type is a
 #     DataBuffer subclass (pointee is spec.config.databuffer_type if set, else scalar_type)
 #   - NoneType: void
 #   - Matrix: Eigen::SparseMatrix when name is in spec.sparse_mat_data, otherwise the dense
 #     matrix_type built from T_or_value.shape
 #   - Values: looked up by name in spec.namespaces_dict and spec.typenames_dict
 #   - sequences: std::array of the first element's type, sized by the sequence length
 #   - anything else: TypeName<scalar_type>, prefixed with sym:: when the type's module name
 #     contains "geo" or "cam"
 #
 # `spec` is not a parameter; it has to be resolvable wherever this macro is rendered.
 # NOTE(review): only the Matrix branch checks `spec is defined`; the DataBuffer and Values
 # branches dereference spec unconditionally - confirm this is intended.
 #
 # Args:
 #     T_or_value (type or Element):
 #     name (str): Optional name in case type is a generated struct
 #     scalar_type (str): C++ scalar type to emit, "Scalar" by default
 #}
{%- macro format_typename(T_or_value, name, scalar_type="Scalar") %}
    {%- set T = typing_util.get_type(T_or_value) -%}
    {%- if T.__name__ == 'Symbol' or is_symbolic(T_or_value) or T.__name__ == 'float' -%}
        {%- if issubclass(T, DataBuffer) -%}
            {%- if spec.config.databuffer_type is not none -%}
            {{ spec.config.databuffer_type }}* const
            {%- else -%}
            {{ scalar_type }}* const
            {%- endif -%}
        {%- else -%}
            {{ scalar_type }}
        {%- endif -%}
    {%- elif T.__name__ == 'NoneType' -%}
        void
    {%- elif issubclass(T, Matrix) -%}
        {%- if spec is defined and name in spec.sparse_mat_data -%}
        Eigen::SparseMatrix<{{ scalar_type }}>
        {%- else -%}
        {{ matrix_type(T_or_value.shape[0], T_or_value.shape[1], scalar_type) }}
        {%- endif -%}
    {%- elif issubclass(T, Values) -%}
        {{ spec.namespaces_dict[name] }}::{{ spec.typenames_dict[name] }}
    {%- elif is_sequence(T_or_value) -%}
        std::array<{{ format_typename(T_or_value[0], name, scalar_type) }}, {{ T_or_value | length }}>
    {%- else -%}
        {%- if "geo" in T.__module__ or "cam" in T.__module__ -%}
        sym::
        {%- endif -%}
        {{- T.__name__ -}}<{{ scalar_type }}>
    {%- endif -%}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Emit the C++ return type of the generated function.
 #
 # This is the type of the entry of the output Values whose key is spec.return_key, or void
 # when spec.return_key is none.
 #
 # Args:
 #     spec (Codegen):
 #     scalar_type (str): C++ scalar type forwarded to format_typename
 #}
{%- macro get_return_type(spec, scalar_type="Scalar") -%}
    {%- set key = spec.return_key -%}
    {{- "void" if key is none else format_typename(spec.outputs[key], key, scalar_type) -}}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

 {# Format function docstring as a C++ block comment
 #
 # Renders no comment when docstring is empty. Otherwise the docstring has its leading
 # whitespace stripped and is split on newlines; every line is written after " *" with its
 # trailing whitespace removed, so a blank docstring line becomes a bare " *".
 #
 # Args:
 #     docstring (str):
 #}
{% macro print_docstring(docstring) %}
{% if docstring %}

/**
{% for line in docstring.lstrip().split('\n') %}
 *{{ ' {}'.format(line).rstrip() }}
{% endfor %}
 */
{%- endif %}
{% endmacro %}

{# ------------------------------------------------------------------------- #}

{# Format function input argument
 #
 # Arguments whose type is named "Symbol" or is DataBuffer are emitted as `const <type> name`;
 # every other argument is emitted as a const reference, `const <type>& name`.
 #
 # Args:
 #     T_or_value (type or Element):
 #     name (str):
 #     scalar_type (str): C++ scalar type forwarded to format_typename
 #}
{%- macro format_input_arg(T_or_value, name, scalar_type="Scalar") %}
    {%- set T = typing_util.get_type(T_or_value) -%}
    {% if T.__name__ == "Symbol" or T == DataBuffer-%}
        {#- Scalar type is just const -#}
        const {{ format_typename(T_or_value, "", scalar_type) }} {{ name }}
    {%- else -%}
        {#- Otherwise assume const reference -#}
        const {{ format_typename(T_or_value, name, scalar_type) }}& {{ name }}
    {%- endif %}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Emit an output argument passed as a const pointer: `<type>* const name`.
 #
 # Args:
 #     T_or_value (type or Element):
 #     name (str):
 #     add_default (bool): Append ` = nullptr` as the default value?
 #     scalar_type (str): C++ scalar type forwarded to format_typename
 #}
{%- macro format_pointer_arg(T_or_value, name, add_default, scalar_type="Scalar") -%}
    {%- set default_suffix = " = nullptr" if add_default else "" -%}
    {{ format_typename(T_or_value, name, scalar_type) }}* const {{ name }}{{ default_suffix }}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Format Eigen::Map argument
 #
 # Emits `Eigen::Map<type> name`, where the mapped type comes from format_typename. No default
 # value is emitted.
 #
 # Args:
 #     T_or_value (type or Element):
 #     name (str):
 #     scalar_type (str): C++ scalar type forwarded to format_typename
 #}
{%- macro format_map_arg(T_or_value, name, scalar_type="Scalar") %}
    Eigen::Map<{{- format_typename(T_or_value, name, scalar_type) -}}> {{ name }}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate input arguments declaration.
 #
 # Emits every entry of spec.inputs except "self" via format_input_arg, followed by every entry
 # of spec.outputs except spec.return_key as an output argument. Matrix outputs become Eigen::Map
 # arguments when spec.config.use_maps_for_outputs is set; all other outputs become pointer
 # arguments.
 #
 # Comma placement: an input is followed by ", " unless it is the last input and no output
 # argument follows it. An output argument is followed by ", " unless it is the last output, or
 # it is the second-to-last output and the last one is the returned key.
 #
 # Args:
 #     spec (Codegen):
 #     is_declaration (bool): Is this a declaration as opposed to a definition? (declarations
 #         include default arguments)
 #     scalar_type (str): C++ scalar type forwarded to the argument formatters
 #}
{%- macro input_args_declaration(spec, is_declaration, scalar_type="Scalar") %}
    {%- for name, type in spec.inputs.items() -%}
        {%- if name != "self" -%}
            {{ format_input_arg(type, name, scalar_type) }}
            {%- if not loop.last
                or spec.outputs.items() | length > 1
                or (spec.outputs.items() | length == 1 and spec.return_key is none) -%}
            , {% endif -%}
        {%- endif -%}
    {%- endfor -%}
    {%- for name, type in spec.outputs.items() -%}
        {%- if name != spec.return_key -%}
            {# Should this be per-output? #}
            {% if issubclass(typing_util.get_type(type), Matrix) and spec.config.use_maps_for_outputs %}
                {{- format_map_arg(type, name, scalar_type) -}}
            {% else %}
                {{- format_pointer_arg(type, name, is_declaration, scalar_type) -}}
            {% endif %}
            {%- if not loop.last -%}
                {%- if not (loop.revindex0 == 1 and loop.nextitem[0] == spec.return_key) -%}
                , {% endif -%}
            {% endif -%}
        {%- endif -%}
    {%- endfor -%}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Turn a function declaration into a method declaration by appending const if is_instance_method
 # or prepending static if it's a declaration (as opposed to a definition).
 #
 # When neither applies (a definition of a non-instance method), declaration_base is emitted
 # without any qualifier. is_instance_method takes precedence over is_declaration.
 #
 # Args:
 #     declaration_base (str): The function declaration we want to decorate with const or static
 #         as appropriate.
 #     is_instance_method (bool): Is this an instance method that we should append const to as
 #         opposed to a static method?
 #     is_declaration (bool): Is this a declaration as opposed to a definition? (only declarations
 #         are marked as static)
 #}
{%- macro make_method_declaration(declaration_base, is_instance_method, is_declaration) -%}
{% if is_instance_method -%}
    {{ declaration_base }} const
{% elif is_declaration -%}
    static {{ declaration_base }}
{% else %}
    {{ declaration_base }}
{% endif %}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate function declaration
 #
 # Emits `<return type> <Name>(<arguments>)`, where the name is spec.name converted with
 # python_util.snakecase_to_camelcase.
 #
 # Args:
 #     spec (Codegen):
 #     is_declaration (bool): Is this a declaration as opposed to a definition? (declarations
 #         include default arguments)
 #     scalar_type (str): C++ scalar type forwarded to the return type and argument formatters
 #}
{%- macro function_declaration(spec, is_declaration, scalar_type="Scalar") -%}
{% set name = python_util.snakecase_to_camelcase(spec.name) %}
{{ get_return_type(spec, scalar_type) }} {{ name }}({{- input_args_declaration(spec, is_declaration, scalar_type) -}})
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate method declaration
 #
 # Wraps function_declaration with make_method_declaration, treating the function as an
 # instance method when "self" is one of spec.inputs.
 #
 # Args:
 #     spec (Codegen):
 #     is_declaration (bool): Is this a declaration as opposed to a definition? (declarations
 #         include default arguments)
 #     scalar_type (str): C++ scalar type forwarded to function_declaration. Defaults to "Scalar",
 #         which matches the behavior before this argument existed.
 #
 # Precondition:
 #     "self" in spec.inputs iff spec specifies an instance method
 #}
{%- macro method_declaration(spec, is_declaration, scalar_type="Scalar") -%}
{{
    make_method_declaration(
        function_declaration(spec, is_declaration, scalar_type),
        "self" in spec.inputs,
        is_declaration
    )
}}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate function declaration with custom namespace
 #
 # Emits `<return type> <namespace>::<Name>(<arguments>)`, where the name is spec.name converted
 # with python_util.snakecase_to_camelcase.
 #
 # Args:
 #     spec (Codegen):
 #     namespace (str): Qualifier written before the function name, joined with `::`
 #     is_declaration (bool): Is this a declaration as opposed to a definition? (declarations
 #         include default arguments)
 #     scalar_type (str): C++ scalar type forwarded to the return type and argument formatters
 #}
{%- macro function_declaration_custom_namespace(spec, namespace, is_declaration, scalar_type="Scalar") -%}
{% set name = python_util.snakecase_to_camelcase(spec.name) %}
{{ get_return_type(spec, scalar_type) }} {{ namespace }}::{{ name }}({{- input_args_declaration(spec, is_declaration, scalar_type) -}})
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate method declaration with custom namespace
 #
 # Wraps function_declaration_custom_namespace with make_method_declaration, treating the
 # function as an instance method when "self" is one of spec.inputs.
 #
 # Args:
 #     spec (Codegen):
 #     namespace (str):
 #     is_declaration (bool): Is this a declaration as opposed to a definition? (declarations
 #         include default arguments)
 #     scalar_type (str): C++ scalar type forwarded to function_declaration_custom_namespace.
 #         Defaults to "Scalar", which matches the behavior before this argument existed.
 #
 # Precondition:
 #     "self" in spec.inputs iff spec specifies an instance method
 #}
{%- macro method_declaration_custom_namespace(spec, namespace, is_declaration, scalar_type="Scalar") -%}
{{
    make_method_declaration(
        function_declaration_custom_namespace(spec, namespace, is_declaration, scalar_type),
        "self" in spec.inputs,
        is_declaration
    )
}}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Initialize sparse matrix if appropriate
 #
 # Emits static constexpr constants describing the sparsity pattern (rows, cols, number of
 # nonzeros, column pointers, row indices), each suffixed with the output name. Then:
 #   - if name is spec.return_key, declares a local Eigen::SparseMatrix called `name`, built from
 #     an Eigen::Map over those constants and a scratch value array;
 #   - otherwise `name` is used as a pointer: the pointee is rebuilt the same way only when its
 #     nonzero count, outer size, inner size or compressed state differ from the expected layout.
 # In both cases a `<name>_value_ptr` pointer to the matrix values is declared last.
 #
 # The nested construct_sparse_map macro emits the Eigen::Map expression including its trailing
 # semicolon, which is why its call sites do not add one.
 #
 # Args:
 #     name (str): Key of sparse matrix
 #     sparse_format (CSCFormat): The python sparse representation of the matrix to be initialized
 #     spec (Codegen): Codegen specification containing sparse_mat_data
 #     scalar_type (str): C++ scalar type of the matrix values
 #}
{%- macro sparse_matrix_init(name, sparse_format, spec, scalar_type="Scalar") -%}
    static constexpr int kRows_{{ name }} = {{ sparse_format.kRows }};
    static constexpr int kCols_{{ name }} = {{ sparse_format.kCols }};
    static constexpr int kNumNonZero_{{ name }} = {{ sparse_format.kNumNonZero }};
    static constexpr int kColPtrs_{{ name }}[] = {
    {%- for i in sparse_format.kColPtrs -%}
        {{ i }}{%- if not loop.last -%}, {% endif -%}
    {%- endfor -%} };
    static constexpr int kRowIndices_{{ name }}[] = {
    {%- for i in sparse_format.kRowIndices -%}
        {{ i }}{%- if not loop.last -%}, {% endif -%}
    {%- endfor -%} };

    {%- macro construct_sparse_map() -%}
    Eigen::Map<const Eigen::SparseMatrix<{{ scalar_type }}>>(
        kRows_{{ name }},
        kCols_{{ name }},
        kNumNonZero_{{ name }},
        kColPtrs_{{ name }},
        kRowIndices_{{ name }},
        {{ name }}_empty_value_ptr
    );
    {%- endmacro -%}

    {%- if name == spec.return_key %}

    {{ scalar_type }} {{ name }}_empty_value_ptr[{{ sparse_format.kNumNonZero }}];
    Eigen::SparseMatrix<{{ scalar_type }}> {{ name }} = {{ construct_sparse_map() }}
    {{ scalar_type }}* {{ name }}_value_ptr = {{ name }}.valuePtr();
    {% else -%}

    if ({{ name }}->nonZeros() != {{ sparse_format.kNumNonZero }}
        || {{ name }}->outerSize() != {{ sparse_format.kCols }}
        || {{ name }}->innerSize() != {{ sparse_format.kRows }}
        || !{{ name }}->isCompressed()) {
        // Matrix does not have the expected layout, create a correctly initialized sparse matrix
        {{ scalar_type }} {{ name }}_empty_value_ptr[{{ sparse_format.kNumNonZero }}];
        *{{ name }} = {{ construct_sparse_map() }}
    }
    {{ scalar_type }}* {{ name }}_value_ptr = {{ name }}->valuePtr();

    {%- endif %}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Helper to generate code to fill out an output object, either returned or as an output
 # argument pointer
 #
 # The output is written through a local named `_<name>`, except for Matrix outputs passed as
 # Eigen::Map (not the return key, with spec.config.use_maps_for_outputs), which are written
 # directly through `<name>`:
 #   - Matrix, Values, symbolic and sequence outputs: `_<name>` is a new object when this is the
 #     returned output, otherwise a reference to `*<name>`.
 #   - any other type (geo/cam objects): `_<name>` is a temporary vector of the type's storage
 #     dimension; for a non-returned output the object is constructed from it and assigned to
 #     `*<name>` at the end, passing `/* normalize */ false` unless spec.config.normalize_results.
 #
 # For Matrix outputs where should_set_zero(...) is true, setZero() is emitted first and terms
 # whose right-hand side is the string "0" are skipped.
 #
 # Args:
 #     name (str): Name of the output object
 #     type (type): Type of the output object
 #     terms (List[Tuple[str]]): List of output terms for this object
 #     spec (Codegen):
 #     scalar_type (str): C++ scalar type forwarded to the type formatters
 #}
{% macro format_output_dense(name, type, terms, spec, scalar_type="Scalar") -%}
{% set T = typing_util.get_type(type) %}
{% set created_output_temporary = False %}
{% if issubclass(T, Matrix) and name != spec.return_key and spec.config.use_maps_for_outputs %}
    {# Eigen::Map outputs don't need a temporary #}
{% elif issubclass(T, Matrix) or issubclass(T, Values) or is_symbolic(type) or is_sequence(type) %}
    {% if name == spec.return_key %}
        {# Create a new object to return #}
{{ format_typename(type, name, scalar_type) }} _{{ name }};
    {% else %}
        {# Get reference to output passed by reference #}
{{ format_typename(type, name, scalar_type) }}& _{{ name }} = (*{{ name }});
    {% endif %}
{% else %}
    {# geo/cam object. Since we can't access individual element of data, create a copy #}
    {# TODO(nathan): Maybe add a [] operator to geo/cam objects so we don't have to do this? #}
    {% set dims = ops.StorageOps.storage_dim(type) %}
    {% set created_output_temporary = True %}
{{ vector_type(dims, scalar_type) }} _{{ name }};
{% endif %}

{% if issubclass(T, Matrix) and spec.config.use_maps_for_outputs and name != spec.return_key %}
{% set maybe_underscore = "" %}
{% else %}
{% set maybe_underscore = "_" %}
{% endif %}
{% set set_zero = issubclass(T, Matrix) and should_set_zero(
    spec.outputs[name], spec.config.zero_initialization_sparsity_threshold) %}
{% if set_zero %}
{{ maybe_underscore }}{{ name }}.setZero();
{% endif %}

{% for lhs, rhs in terms %}
{% if not set_zero or rhs != "0" %}
{{ maybe_underscore }}{{ lhs }} = {{ rhs }};
{% endif %}
{% endfor %}
{% if name != spec.return_key and created_output_temporary %}
    {% set typename = format_typename(type, name, scalar_type) %}
    {% set normalize = "" if spec.config.normalize_results else ", /* normalize */ false" %}

*{{ name }} = {{ typename }}(_{{ name }}{{ normalize }});
{% endif %}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Helper to generate all output arguments
 #
 # Iterates spec.print_code_results.dense_terms and then sparse_terms. The returned output
 # (name == spec.return_key) is filled unconditionally. Every other output is optional and is
 # filled inside a null check: `name.data() != nullptr` for Matrix outputs passed as Eigen::Map
 # (spec.config.use_maps_for_outputs), `name != nullptr` otherwise. Sparse outputs are first
 # prepared with sparse_matrix_init and then assigned term by term.
 #
 # Args:
 #     spec (Codegen):
 #     scalar_type (str): C++ scalar type forwarded to the output formatters
 #}
{% macro format_outputs_dense_and_sparse(spec, scalar_type="Scalar") -%}
    {% for name, type, terms in spec.print_code_results.dense_terms %}
        {% if name == spec.return_key %}
    {{ format_output_dense(name, type, terms, spec, scalar_type) | indent(width=4) | trim }}
        {% else %}
            {% if issubclass(typing_util.get_type(type), Matrix) and spec.config.use_maps_for_outputs %}
    if ( {{ name }}.data() != nullptr ) {
            {% else %}
    if ( {{ name }} != nullptr ) {
            {% endif %}
        {{ format_output_dense(name, type, terms, spec, scalar_type) | indent(width=8) | trim }}
    }
        {% endif %}

    {% endfor %}
    {% for name, type, terms in spec.print_code_results.sparse_terms %}
        {% if name == spec.return_key %}
    {{ sparse_matrix_init(name, spec.sparse_mat_data[name], spec, scalar_type) }}

            {% for lhs, rhs in terms %}
    {{ lhs }} = {{ rhs }};
            {% endfor %}
        {% else %}
    if ( {{ name }} != nullptr ) {
        {{ sparse_matrix_init(name, spec.sparse_mat_data[name], spec, scalar_type) | indent(width=4) }}

            {% for lhs, rhs in terms %}
        {{ lhs }} = {{ rhs }};
            {% endfor %}
    }
        {% endif %}

    {% endfor %}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate inner code for computing the given expression.
 #
 # Emits, in order:
 #   - a comment with spec.total_ops()
 #   - a (void) cast for each of spec.unused_arguments
 #   - a const reference `_<name>` to the Data() of every used input that is not a Values, a
 #     Matrix, symbolic or a sequence ("self" reads Data() on the enclosing object)
 #   - the intermediate terms as const scalars
 #   - the output terms, via format_outputs_dense_and_sparse
 #   - the return statement for spec.return_key, unless that output's type is NoneType
 #
 # Args:
 #     spec (Codegen):
 #     scalar_type (str): C++ scalar type used for the storage vectors, the intermediate terms and
 #         the outputs
 #}
{% macro expr_code(spec, scalar_type="Scalar") -%}
    // Total ops: {{ spec.total_ops() }}

    {% if spec.unused_arguments %}
    // Unused inputs
    {% for arg in spec.unused_arguments %}
    (void){{ arg }};
    {% endfor %}

    {% endif %}
    // Input arrays
    {% for name, type in spec.inputs.items() %}
        {% set T = typing_util.get_type(type) %}
        {% if name not in spec.unused_arguments and not issubclass(T, Values) and not issubclass(T, Matrix) and not is_symbolic(type) and not is_sequence(type) %}
            {% if name == "self" %}
    const {{ vector_type(ops.StorageOps.storage_dim(type), scalar_type) }}& _{{ name }} = Data();
            {% else %}
    const {{ vector_type(ops.StorageOps.storage_dim(type), scalar_type) }}& _{{ name }} = {{ name }}.Data();
            {% endif %}
        {% endif %}
    {% endfor %}

    // Intermediate terms ({{ spec.print_code_results.intermediate_terms | length }})
    {% for lhs, rhs in spec.print_code_results.intermediate_terms %}
    const {{ format_typename(Symbol, "", scalar_type) }} {{ lhs }} = {{ rhs }};
    {% endfor %}

    // Output terms ({{ spec.outputs.items() | length }})
    {{ format_outputs_dense_and_sparse(spec, scalar_type) | trim }}
    {# Populate the output vectors with the generated expressions #}
    {% for name, type in spec.outputs.items() %}
        {% set T_return = typing_util.get_type(type) %}
        {% if name == spec.return_key and T_return.__name__ != 'NoneType' %}

            {% if name in spec.sparse_mat_data %}
    return {{ name }};
            {% elif issubclass(T_return, Matrix) or issubclass(T_return, Values) or is_symbolic(type) or is_sequence(type) %}
    return _{{ name }};
            {% else %}
                {% set typename = format_typename(type, name, scalar_type) %}
                {% set normalize = "" if spec.config.normalize_results else ", /* normalize */ false" %}
    return {{ typename }}(_{{ name }}{{ normalize }});
            {% endif %}
        {% endif %}
    {% endfor %}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Emit the initializer expression for a sym::Key.
 #
 # A key with a subscript is written as a braced `{'letter', sub}` pair; a key without one is
 # written as the quoted letter alone.
 #
 # Args:
 #     generated_key (namedtuple(letter, sub))
 #}
{%- macro format_sym_key(generated_key) -%}
    {%- if generated_key.sub is not none -%}
        {'{{ generated_key.letter }}', {{ generated_key.sub }}}
    {%- else -%}
        '{{ generated_key.letter }}'
    {%- endif -%}
{%- endmacro -%}
